""" DiffGro Policies Implementation """
from functools import partial
from typing import Any, Tuple, NamedTuple, List, Dict, Union, Type, Optional, Callable

import gym
import jax
import jax.numpy as jnp
import haiku as hk
import optax

from sb3_jax.common.policies import BasePolicy
from sb3_jax.common.norm_layers import BaseNormLayer
from sb3_jax.common.jax_layers import BaseFeaturesExtractor, FlattenExtractor
from sb3_jax.common.preprocessing import get_flattened_obs_dim, get_act_dim
from sb3_jax.common.type_aliases import GymEnv, MaybeCallback, Schedule
from sb3_jax.common.utils import get_dummy_decision_transformer, get_dummy_obs, get_dummy_act
from sb3_jax.du.policies import DiffusionBetaScheduler

from diffgro.utils.utils import print_b 
from diffgro.common.models.helpers import MLP
from diffgro.common.models.utils import sample_dist
from diffgro.common.models.diffusion import UNetDiffusion, Diffusion
from diffgro.common.models.vae import MLPEncoder, LSTMEncoder, DiffusionVAE
from diffgro.diffgro.functions import calculate_grad


@jax.jit
def apply_mask(mask: jax.Array, x_t: jax.Array, cond: jax.Array):
    """ blend `x_t` with `cond`: where mask is 1 the result is `cond`, where mask is 0 it stays `x_t` """
    keep = 1 - mask  # weight given to the current sample
    return keep * x_t + mask * cond


class Actor(BasePolicy):
    """ actor class for diffgro policy

    Builds a `DiffusionVAE` out of an `LSTMEncoder` (skill encoder) and a
    `Diffusion` planner wrapping a `UNetDiffusion`, and implements the guided
    reverse-diffusion loop (`_denoise_act`). The denoised sample has shape
    (batch, horizon, obs_dim + act_dim).
    """
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        net_arch: Optional[List[int]] = None,
        activation_fn: str = 'mish',
        # embedding
        horizon: int = 8,       # planning horizon
        skill_dim: int = 512,   # semantic skill embedding dimension
        emb_dim: int = 64,      # embedding dimension
        hid_dim: int = 128,     # lstm hidden dimension
        ctx_dim: int = 64,      # skill embedding dimension
        # diffusion
        n_denoise: int = 20,    # denoising timestep
        cf_weight: float = 1.0, # classifier-free guidance weight
        predict_epsilon: bool = False,  # predict noise / original
        beta_scheduler: str = 'linear', # denoising scheduler
        seed: int = 1,
    ):
        super(Actor, self).__init__(
            observation_space,
            action_space,
            squash_output=False,
            seed=seed,
        )

        # NOTE(review): `net_arch` is stored but not read anywhere in this class;
        # `_build_dec` uses a literal `dim_mults=(1,4,8)`.
        self.net_arch = net_arch
        self.activation_fn = activation_fn

        # embedding
        self.horizon = horizon
        self.skill_dim = skill_dim
        self.emb_dim = emb_dim
        self.hid_dim = hid_dim
        self.ctx_dim = ctx_dim

        # diffusion
        self.n_denoise = n_denoise
        self.cf_weight = cf_weight
        self.predict_epsilon = predict_epsilon
        self.beta_scheduler = beta_scheduler
        # precomputed schedule coefficients, indexed by denoise step in `_sample`
        self.ddpm_dict = DiffusionBetaScheduler(None, None, n_denoise, beta_scheduler).schedule()

        # misc
        self.obs_dim = get_flattened_obs_dim(self.observation_space)
        self.act_dim = get_act_dim(self.action_space)
        self.out_dim = self.obs_dim + self.act_dim  # one trajectory step = [obs, act]

        self._build()

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        """ return the constructor parameters reported by the base class, unchanged """
        data = super()._get_constructor_parameters()
        return data

    def _build_enc(self, batch_keys: List[str]) -> hk.Module:
        """ build the LSTM skill encoder reading the `batch_keys` entries of its batch """
        # skill encoder
        return LSTMEncoder(
            horizon=self.horizon,
            emb_dim=self.emb_dim,   # embedding dimension
            ctx_dim=self.ctx_dim,   # skill embedding dimension
            hid_dim=self.hid_dim,
            batch_keys=batch_keys,
            activation_fn=self.activation_fn,
        )

    def _build_dec(self, batch_keys: List[str]) -> hk.Module:
        """ build the diffusion planner (UNet wrapped in a DDPM `Diffusion` module) """
        # diffusion planner
        unet = UNetDiffusion(
            horizon=self.horizon,
            emb_dim=self.emb_dim,
            out_dim=self.out_dim,
            dim_mults=(1,4,8),
            attention=False,
            batch_keys=batch_keys,
            activation_fn=self.activation_fn
        )
        return Diffusion(
            diffusion=unet,
            n_denoise=self.n_denoise,
            ddpm_dict=self.ddpm_dict,
            guidance_weight=self.cf_weight,         # classifier-free guidance weight
            predict_epsilon=self.predict_epsilon,   # True: predict noise, False: predict original
            denoise_type='ddpm',
        )

    def _build(self) -> None:
        """ create the haiku transform for the VAE and initialise its parameters on dummy inputs """
        # dummy inputs
        dummy_obs, dummy_act = get_dummy_obs(self.observation_space), get_dummy_act(self.action_space)
        dummy_obs_stack = jnp.repeat(dummy_obs, self.horizon, axis=0).reshape(1, self.horizon, -1) # stacked observation
        dummy_act_stack = jnp.repeat(dummy_act, self.horizon, axis=0).reshape(1, self.horizon, -1) # stacked action
        dummy_traj = jnp.concatenate((dummy_obs_stack, dummy_act_stack), axis=-1)   # trajectory
        dummy_t = jnp.array([[1.]])

        def fn_act(x_t: jax.Array, batch_dict: Dict[str, jax.Array], t: jax.Array, ctx: jax.Array, denoise: bool, deterministic: bool):
            enc = self._build_enc(batch_keys=["obs", "act"])
            dec = self._build_dec(batch_keys=["ctx"])
            vae = DiffusionVAE(enc, dec)
            return vae(x_t, batch_dict, t, ctx, denoise, deterministic)
        init_fn, self.pi = hk.transform(fn_act)
        enc_batch_dict = {"obs": dummy_obs_stack, "act": dummy_act_stack}
        dec_batch_dict = {}
        batch_dict = {"enc": enc_batch_dict, "dec": dec_batch_dict}
        self.params = init_fn(next(self.rng), dummy_traj, batch_dict, dummy_t, ctx=None, denoise=False, deterministic=False)

    @partial(jax.jit, static_argnums=(0,5,6))
    def _pi(
        self,
        x_t: jax.Array,
        batch_dict: Dict[str, jax.Array],
        t: jax.Array,
        ctx: jax.Array,
        denoise: bool,
        deterministic: bool,
        params: hk.Params,
        rng=None
    ) -> Tuple[Tuple[jax.Array], Dict[str, jax.Array]]:
        """ jitted forward pass of the VAE; `self`, `denoise` and `deterministic` are static """
        # return: (mean, std, eps), info
        return self.pi(params, rng, x_t, batch_dict, t, ctx, denoise, deterministic)

    # encoder prediction
    def _predict_enc(
        self,
        obs: jax.Array,
        act: jax.Array,
        deterministic: bool = False,
    ) -> Tuple[jax.Array, jax.Array]:
        """ run the network with only the encoder inputs set and return its (mean, std) """
        enc_batch_dict = {"obs": obs, "act": act}
        batch_dict = {"enc": enc_batch_dict}
        (mean, std, _), info = self._pi(None, batch_dict, None, None, False, deterministic, self.params, next(self.rng))
        return mean, std

    # decoder prediction
    def _predict(
        self,
        x_t: jax.Array,
        skill: jax.Array,
        t: int,
        ctx: jax.Array,
        deterministic: bool = False,
    ) -> Tuple[jax.Array, Dict[str, jax.Array]]:
        """ run the network on the noisy sample `x_t` at denoise step `t` and return (eps, info) """
        # return: eps, info
        dec_batch_dict = {"skill": skill}
        batch_dict = {"dec": dec_batch_dict}
        ts = jnp.full((x_t.shape[0], 1), t)  # same step index for every batch row
        (_, _, eps), info = self._pi(x_t, batch_dict, ts, ctx, False, deterministic, self.params, next(self.rng))
        return eps, info

    # one denoise timestep prediction with guidance
    @partial(jax.jit, static_argnums=(0,4))
    def _sample(
        self,
        x_t: jax.Array,
        eps: jax.Array,
        t: int,
        deterministic: bool,
        rng=None,
    ) -> jax.Array:
        """ compute x_{t-1} from `x_t` and the network output `eps` for denoise step `t`

        No noise is added when `deterministic` is True. `self` and
        `deterministic` are static arguments of the jitted function.
        """
        batch_size = x_t.shape[0]
        # LINE 3: sample noise
        noise = jax.random.normal(rng, shape=(batch_size, self.horizon, self.out_dim)) if not deterministic else 0.

        # LINE 4: calculate x_{t-1}
        if self.predict_epsilon: # noise prediction
            x_t = self.ddpm_dict.oneover_sqrta[t] * (x_t - self.ddpm_dict.ma_over_sqrtmab_inv[t] * eps) \
                    + self.ddpm_dict.sqrt_beta_t[t] * noise
        else: # original action prediction
            x_t = self.ddpm_dict.posterior_mean_coef1[t] * eps + self.ddpm_dict.posterior_mean_coef2[t] * x_t \
                    + jnp.exp(0.5 * self.ddpm_dict.posterior_log_beta[t]) * noise
        return x_t

    @staticmethod
    def _guidance_scale(loss) -> float:
        """ return the power of ten that brings |loss| into [0.1, 1.0] (1.0 for a zero loss)

        Raises ValueError when the loss is NaN or infinite: such a value can
        never be normalised, so the rescaling loop below would not terminate.
        """
        if loss < 0.0:
            loss = -loss
        if not bool(jnp.isfinite(loss)):
            raise ValueError(f"guidance loss is not finite: {loss}")

        count = 0
        while loss != 0.0 and not (0.1 <= loss <= 1.0):
            if loss > 1.0:
                loss /= 10
                count -= 1
            else:
                loss *= 10
                count += 1
        try:
            # float power: an integer `10 ** count` can overflow once it is
            # multiplied into a jax array
            return 10.0 ** count
        except OverflowError as err:
            raise ValueError(f"guidance loss is too small to rescale (10 ** {count})") from err

    def _denoise_act(
        self,
        cond: jax.Array,
        mask: jax.Array,
        ctx: jax.Array,
        skill: jax.Array,
        delta: float = 0.1,
        guide_fn: Optional[Callable] = None,  # guidance function
        n_guide_steps: int = 1,     # number of guidance steps
        deterministic: bool = True,
        verbose: bool = False,
    ) -> Tuple[jax.Array, Dict[str, jax.Array]]:
        """ run the full reverse-diffusion loop, optionally steering it with `guide_fn`

        :param cond: conditioning values written into the sample wherever `mask` is 1
        :param mask: blending mask passed to `apply_mask`; masked entries also receive no gradient
        :param ctx: context forwarded to the network
        :param skill: skill embedding forwarded to the network
        :param delta: step size of the guidance update
        :param guide_fn: guidance function handed to `calculate_grad`; None disables guidance
        :param n_guide_steps: number of guidance updates per denoise step
        :param deterministic: if True no noise is added while sampling
        :param verbose: print the loss, gradient and prediction of each guided step
        :return: the denoised sample and {"grad": last guidance loss, or None if guidance never ran}
        :raises ValueError: if a guidance loss is NaN/inf or cannot be rescaled
        """
        batch_size = cond.shape[0]

        # LINE 0: sample the initial noise and overwrite the conditioned entries
        x_t = jax.random.normal(next(self.rng), shape=(batch_size, self.horizon, self.out_dim))
        x_t = apply_mask(mask, x_t, cond)

        # loss of the most recent guidance step; stays None when guidance never runs
        last_loss = None

        for t in range(self.n_denoise, 0, -1):
            # LINE 1: predict epsilon (= x_hat_0)
            eps, _ = self._predict(x_t, skill, t, ctx, deterministic)
            original_eps = eps

            # cfg sampling
            # LINE 2: apply gradient (the first two denoise steps are never guided)
            if (guide_fn is not None) and (t <= self.n_denoise - 2):
                grad = None
                for _ in range(n_guide_steps):
                    # calculate gradient
                    grad, grad_info = calculate_grad(guide_fn, eps, self.obs_dim)
                    last_loss = grad_info['loss']
                    # gradient scaling
                    grad = grad * self._guidance_scale(last_loss)
                    # apply masking
                    grad = (1 - mask) * grad
                    # apply gradient
                    eps = eps - delta * grad  # jnp.exp(self.ddpm_dict.posterior_log_beta[t])

                # nothing to report when no guidance step ran (n_guide_steps == 0)
                if verbose and grad is not None:
                    jnp.set_printoptions(precision=5, suppress=True)
                    print("="*30)
                    print(f"Loss: {last_loss}")
                    print(f"Scale at {t}: {jnp.exp(self.ddpm_dict.posterior_log_beta[t]):.3f}")
                    print(f"Original EPS: {original_eps[:,:2,-self.act_dim:]}")
                    print(f"GRD: {grad[:,:,-self.act_dim:]}")
                    print(f"EPS: {eps[:,:,-self.act_dim:]}")

            # LINE 3: sampling
            x_t = self._sample(x_t, eps, t, deterministic, next(self.rng))
            x_t = apply_mask(mask, x_t, cond)
        return x_t, {"grad": last_loss}

    def _load_jax_params(self, params: Dict[str, hk.Params]) -> None:
        """ replace the actor parameters with `params["pi_params"]` """
        print_b("[diffgro/actor]: loading params")
        self.params = params["pi_params"]


class Prior(BasePolicy):
    """ skill prior class for diffgro policy

    Wraps an `MLPEncoder` that reads "obs" and "task" (plus "skill" when the
    domain is not 'short') and returns a (mean, std) pair, from which
    `_predict` draws a context via `sample_dist`.
    """
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        net_arch: Optional[List[int]] = None,
        activation_fn: str = 'mish',
        domain: str = 'short',
        # embedding
        skill_dim: int = 512,   # semantic skill embedding dimension
        emb_dim: int = 64,      # embedding dimension
        ctx_dim: int = 64,      # skill embedding dimension
        seed: int = 1,
    ):
        super().__init__(
            observation_space,
            action_space,
            squash_output=False,
            seed=seed,
        )

        # network shape
        self.net_arch, self.activation_fn = net_arch, activation_fn

        # which inputs the prior reads, and their sizes
        self.domain = domain
        self.skill_dim, self.emb_dim, self.ctx_dim = skill_dim, emb_dim, ctx_dim

        self._build()

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        """ constructor parameters exactly as reported by the base class """
        return super()._get_constructor_parameters()

    def _build_pri(self, batch_keys: Dict[str, jax.Array]) -> hk.Module:
        """ create the MLP skill prior; `batch_keys` names the batch entries it reads
        (this class passes a list of strings) """
        encoder_kwargs = dict(
            net_arch=self.net_arch,
            emb_dim=self.emb_dim,
            ctx_dim=self.ctx_dim,
            batch_keys=batch_keys,
            activation_fn=self.activation_fn,
        )
        return MLPEncoder(**encoder_kwargs)

    def _build(self) -> None:
        """ set up the haiku transform and initialise the parameters on placeholder inputs """
        # placeholders; the order of the rng draws is kept so initialisation is reproducible
        task_stub = jax.random.normal(next(self.rng), shape=(1, self.skill_dim))
        skill_stub = jax.random.normal(next(self.rng), shape=(1, self.skill_dim))
        obs_stub = get_dummy_obs(self.observation_space)

        def fn_pri(batch_dict: Dict[str, jax.Array]):
            # the 'short' domain leaves the skill entry out
            if self.domain == 'short':
                keys = ["obs", "task"]
            else:
                keys = ["obs", "task", "skill"]
            return self._build_pri(batch_keys=keys)(batch_dict)

        transformed = hk.transform(fn_pri)
        self.pr = transformed.apply
        init_inputs = {"obs": obs_stub, "task": task_stub, "skill": skill_stub}
        self.params = transformed.init(next(self.rng), init_inputs)

    @partial(jax.jit, static_argnums=(0,))
    def _pr(
        self,
        batch_dict: Dict[str, jax.Array],
        params: hk.Params,
        rng=None
    ) -> Tuple[jax.Array]:
        """ jitted forward pass of the prior network (`self` is a static argument) """
        return self.pr(params, rng, batch_dict)

    def _predict(self, obs: jax.Array, task: jax.Array, skill: jax.Array = None, deterministic: bool = False) -> jax.Array:
        """ evaluate the prior and return a context obtained from its (mean, std) via `sample_dist` """
        inputs = {"obs": obs, "task": task, "skill": skill}
        key = next(self.rng)
        mean, std = self._pr(inputs, self.params, key)
        # return: ctx
        return sample_dist(mean, std, deterministic)

    def _load_jax_params(self, params: Dict[str, hk.Params]) -> None:
        """ replace the prior parameters with `params["pr_params"]` """
        print_b("[diffgro/prior]: loading params")
        self.params = params["pr_params"]


class DiffGroPlannerPolicy(BasePolicy):
    """ policy class for diffgro

    Owns an `Actor` (skill encoder + diffusion planner) and a `Prior` (skill
    prior), each with its own optimizer and optimizer state, and preprocesses
    observations before delegating to them.
    """
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[List[int]] = None,
        activation_fn: str = 'mish',
        domain: str = 'short',
        # embedding
        horizon: int = 8,       # planning horizon
        skill_dim: int = 512,   # semantic skill embedding dimension
        emb_dim: int = 64,      # embedding dimension
        hid_dim: int = 128,     # lstm hidden dimension
        ctx_dim: int = 64,      # skill embedding dimension
        # diffusion
        n_denoise: int = 20,    # denoising timestep
        cf_weight: float = 1.0, # diffusion classifier-free guidance weight
        predict_epsilon: bool = False,  # predict noise / original
        beta_scheduler: str = 'linear', # denoising scheduler
        # others
        squash_output: bool = False,
        features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Callable = optax.adamw,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        normalization_class: Type[BaseNormLayer] = None,
        normalization_kwargs: Optional[Dict[str, Any]] = None,
        seed: int = 1,
    ):
        super(DiffGroPlannerPolicy, self).__init__(
            observation_space,
            action_space,
            features_extractor_class,
            features_extractor_kwargs,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            normalization_class=normalization_class,
            normalization_kwargs=normalization_kwargs,
            squash_output=squash_output,
            seed=seed,
        )

        # `net_arch` is a dict with one entry per sub-network
        if net_arch is None:
            net_arch = dict(act=(1,4,8), pri=[128,128])
        self.act_arch, self.pri_arch = net_arch['act'], net_arch['pri']
        self.activation_fn = activation_fn

        self.domain = domain
        assert self.domain in ['short', 'long'], 'Domain should be either short or long'
        self.horizon = horizon
        self.skill_dim = skill_dim
        self.emb_dim = emb_dim
        self.hid_dim = hid_dim
        self.ctx_dim = ctx_dim

        self.n_denoise = n_denoise
        self.cf_weight = cf_weight
        self.predict_epsilon = predict_epsilon
        self.beta_scheduler = beta_scheduler

        # construct args shared by the actor and the prior
        self.net_args = {
            "observation_space": self.observation_space,
            "action_space": self.action_space,
            "activation_fn": self.activation_fn,
            "seed": seed,
        }

        # actor kwargs
        self.act_kwargs = self.net_args.copy()
        self.act_kwargs.update({
            "net_arch": self.act_arch,
            "horizon": horizon,
            "skill_dim": skill_dim,
            "emb_dim": emb_dim,
            "hid_dim": hid_dim,
            "ctx_dim": ctx_dim,
            "n_denoise": n_denoise,
            "cf_weight": cf_weight,
            "predict_epsilon": predict_epsilon,
            "beta_scheduler": beta_scheduler,
        })

        # prior kwargs
        self.pri_kwargs = self.net_args.copy()
        self.pri_kwargs.update({
            "domain": domain,
            "net_arch": self.pri_arch,
            # the prior sizes its dummy task/skill inputs from skill_dim, so it
            # must receive the same value as the actor instead of its default
            "skill_dim": skill_dim,
            "emb_dim": emb_dim,
            "ctx_dim": ctx_dim,
        })

        self._build(lr_schedule)

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        """ return the keyword arguments needed to re-create this policy """
        data = super()._get_constructor_parameters()

        data.update(
            dict(
                observation_space=self.observation_space,
                action_space=self.action_space,
                # without these three a re-created policy would silently fall
                # back to the default architecture, activation and domain
                net_arch=dict(act=self.act_arch, pri=self.pri_arch),
                activation_fn=self.activation_fn,
                domain=self.domain,
                horizon=self.horizon,
                skill_dim=self.skill_dim,
                emb_dim=self.emb_dim,
                hid_dim=self.hid_dim,
                ctx_dim=self.ctx_dim,
                n_denoise=self.n_denoise,
                cf_weight=self.cf_weight,
                predict_epsilon=self.predict_epsilon,
                beta_scheduler=self.beta_scheduler,
                optimizer_class=self.optimizer_class,
                optimizer_kwargs=self.optimizer_kwargs,
                features_extractor_class=self.features_extractor_class,
                features_extractor_kwargs=self.features_extractor_kwargs,
                normalization_class=self.normalization_class,
                normalization_kwargs=self.normalization_kwargs,
            )
        )
        return data

    def _build(self, lr_schedule: Tuple[float]) -> None:
        """ create the normalization layer (if any), the actor and the prior with their optimizers """
        # tolerate kwargs that are still None so that `**` unpacking cannot fail
        optimizer_kwargs = self.optimizer_kwargs or {}
        normalization_kwargs = self.normalization_kwargs or {}

        if self.normalization_class is not None:
            self.normalization_layer = self.normalization_class(self.observation_space.shape, **normalization_kwargs)

        # make actor
        self.act = self.make_act()
        self.act.optim = self.optimizer_class(learning_rate=lr_schedule, **optimizer_kwargs)
        self.act.optim_state = self.act.optim.init(self.act.params)

        # make prior
        self.pri = self.make_pri()
        self.pri.optim = self.optimizer_class(learning_rate=lr_schedule, **optimizer_kwargs)
        self.pri.optim_state = self.pri.optim.init(self.pri.params)

    def make_act(self) -> Actor:
        """ instantiate the actor from `self.act_kwargs` """
        return Actor(**self.act_kwargs)

    def make_pri(self) -> Prior:
        """ instantiate the prior from `self.pri_kwargs` """
        return Prior(**self.pri_kwargs)

    def _predict_pri(self, obs: jax.Array, task: jax.Array, skill: jax.Array = None, deterministic: bool = True) -> jax.Array:
        """ preprocess `obs` and return the context predicted by the prior """
        obs = self.preprocess(obs, training=False)
        return self.pri._predict(obs, task, skill, deterministic)

    def _predict_enc(self, obs: jax.Array, act: jax.Array, deterministic: bool = True) -> Tuple[jax.Array]:
        """ preprocess the stacked observations and return the encoder's (mean, std) """
        # preprocess step by step, then restore the (batch, horizon, obs_dim) layout
        obs = self.preprocess(obs.reshape(-1, self.act.obs_dim), training=False)
        obs = obs.reshape(-1, self.act.horizon, self.act.obs_dim)
        return self.act._predict_enc(obs, act, deterministic)

    def _predict_act(
        self,
        cond: jax.Array,
        mask: jax.Array,
        ctx: jax.Array,
        skill: jax.Array,
        delta: float = 1.0,
        guide_fn: Optional[Callable] = None,
        n_guide_steps: int = 1,
        deterministic: bool = True,
        verbose: bool = False,
    ) -> Tuple[jax.Array, Dict[str, jax.Array]]:
        """ preprocess the observation part of `cond` and run the actor's denoising loop

        `cond` carries obs_dim observation features followed by act_dim action
        features on its last axis; only the observation part is preprocessed.
        All other arguments are forwarded to `Actor._denoise_act`.
        """
        # preprocess observation part
        obs, act = cond[:,:,:self.act.obs_dim], cond[:,:,-self.act.act_dim:]
        obs = self.preprocess(obs.reshape(-1, self.act.obs_dim), training=False)
        obs = obs.reshape(-1, self.act.horizon, self.act.obs_dim)
        cond = jnp.concatenate((obs, act), axis=-1)
        return self.act._denoise_act(cond, mask, ctx, skill, delta, guide_fn, n_guide_steps, deterministic, verbose)

    def _predict(self,):
        """ not implemented; use `_predict_pri`, `_predict_enc` or `_predict_act` """
        raise NotImplementedError

# ====================================================================================== #

class DiffGroPredictorPolicy(BasePolicy):
    """ skill termination predictor for diffgro

    An MLP with a single output passed through a sigmoid; it reads the
    "obs_0", "obs" and "skill" entries of its batch.
    """
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[List[int]] = None,
        activation_fn: str = 'mish',
        # embedding
        horizon: int = 4,       # prediction horizon
        skill_dim: int = 512,   # skill embedding dimension
        emb_dim: int = 64,
        # others
        squash_output: bool = False,
        features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Callable = optax.adamw,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        normalization_class: Type[BaseNormLayer] = None,
        normalization_kwargs: Optional[Dict[str, Any]] = None,
        seed: int = 1,
    ):
        super(DiffGroPredictorPolicy, self).__init__(
            observation_space,
            action_space,
            features_extractor_class,
            features_extractor_kwargs,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            normalization_class=normalization_class,
            normalization_kwargs=normalization_kwargs,
            squash_output=squash_output,
            seed=seed,
        )

        if net_arch is None:
            net_arch = [128,128]
        self.net_arch = net_arch
        self.activation_fn = activation_fn

        self.horizon = horizon
        self.emb_dim = emb_dim
        self.skill_dim = skill_dim

        self.obs_dim = get_flattened_obs_dim(self.observation_space)
        self.act_dim = get_act_dim(self.action_space)

        self._build(lr_schedule)

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        """ return the keyword arguments needed to re-create this policy """
        data = super()._get_constructor_parameters()

        data.update(
            dict(
                observation_space=self.observation_space,
                action_space=self.action_space,
                # without these two a re-created policy would silently fall
                # back to the default architecture and activation
                net_arch=self.net_arch,
                activation_fn=self.activation_fn,
                horizon=self.horizon,
                emb_dim=self.emb_dim,
                skill_dim=self.skill_dim,
                optimizer_class=self.optimizer_class,
                optimizer_kwargs=self.optimizer_kwargs,
                features_extractor_class=self.features_extractor_class,
                features_extractor_kwargs=self.features_extractor_kwargs,
                normalization_class=self.normalization_class,
                normalization_kwargs=self.normalization_kwargs,
            )
        )
        return data

    def _build_prd(self, batch_keys: List[str]) -> hk.Module:
        """ build the single-output MLP reading the `batch_keys` entries of its batch """
        return MLP(
            emb_dim=self.emb_dim,
            out_dim=1,
            net_arch=self.net_arch,
            batch_keys=batch_keys,
            activation_fn=self.activation_fn,
            squash_output=False,
        )

    def _build(self, lr_schedule: Tuple[float]) -> None:
        """ create the normalization layer (if any), the network parameters and the optimizer """
        # tolerate kwargs that are still None so that `**` unpacking cannot fail
        optimizer_kwargs = self.optimizer_kwargs or {}
        normalization_kwargs = self.normalization_kwargs or {}

        if self.normalization_class is not None:
            self.normalization_layer = self.normalization_class(self.observation_space.shape, **normalization_kwargs)

        # dummy inputs; obs/act are repeated `horizon` times along axis 1
        dummy_obs = get_dummy_obs(self.observation_space)
        dummy_obs_repeat = dummy_obs.repeat(self.horizon, axis=1)
        dummy_act = get_dummy_act(self.action_space)
        dummy_act_repeat = dummy_act.repeat(self.horizon, axis=1)
        dummy_skill = jax.random.normal(next(self.rng), shape=(1, self.skill_dim))

        def fn_prd(batch_dict: Dict[str, jax.Array]):
            prd = self._build_prd(batch_keys=["obs_0", "obs", "skill"])
            return jax.nn.sigmoid(prd(batch_dict)) # sigmoid
        init_fn, self.pd = hk.transform(fn_prd)
        # "act" is supplied but is not among the batch keys given to the MLP
        batch_dict = {"obs_0": dummy_obs, "obs": dummy_obs_repeat, "act": dummy_act_repeat, "skill": dummy_skill}
        self.params = init_fn(next(self.rng), batch_dict)

        self.optim = self.optimizer_class(learning_rate=lr_schedule, **optimizer_kwargs)
        self.optim_state = self.optim.init(self.params)

    @partial(jax.jit, static_argnums=(0,))
    def _pd(
        self,
        batch_dict: Dict[str, jax.Array],
        params: hk.Params,
        rng=None,
    ) -> jax.Array:
        """ jitted forward pass of the predictor (`self` is a static argument) """
        return self.pd(params, rng, batch_dict)

    def _predict(
        self,
        obs_0: jax.Array,
        obs: jax.Array,
        act: jax.Array,
        skill: jax.Array,
    ) -> jax.Array:
        """ preprocess the observations and return the sigmoid output of the predictor """
        batch_size = obs.shape[0]
        obs_0 = self.preprocess(obs_0, training=False)
        # preprocess step by step, then flatten back to one row per batch element
        obs = self.preprocess(obs.reshape(-1, self.obs_dim), training=False)
        obs = obs.reshape(batch_size, -1)
        batch_dict = {"obs_0": obs_0, "obs": obs, "act": act, "skill": skill}
        return self._pd(batch_dict, self.params, next(self.rng))

    def _load_jax_params(self, params: Dict[str, hk.Params]) -> None:
        """ replace the predictor parameters with `params["pd_params"]` """
        print_b("[diffgro/predictor]: loading params")
        self.params = params["pd_params"]
